
from torch.utils.data import Dataset


class Ratings_Dataset(Dataset):
    """Map-style dataset of (user id, item id, rating) triples.

    Each sample is the tuple ``(user_id, item_id, rating)`` taken from the
    same position in the ``'userId'``, ``'itemId'`` and ``'rating'`` columns
    of ``rating_events``.
    """

    def __init__(self, rating_events):
        """Materialise the three rating columns as Python lists.

        Args:
            rating_events: any object supporting ``rating_events['userId']``,
                ``rating_events['itemId']`` and ``rating_events['rating']``,
                each returning an iterable. Presumably a pandas DataFrame;
                TODO confirm against callers.

        Raises:
            KeyError (or the container's equivalent): if a column is missing.
            ValueError: if the three columns have different lengths.
        """
        # Zero-argument super() starts the MRO lookup at Dataset itself.
        # The previous super(Dataset, self) skipped Dataset.__init__.
        super().__init__()
        self.user_id = list(rating_events['userId'])
        self.item_id = list(rating_events['itemId'])
        self.rating = list(rating_events['rating'])

        # __len__ uses user_id alone, so mismatched columns would surface
        # as an IndexError mid-epoch. Fail fast at construction instead.
        lengths = {len(self.user_id), len(self.item_id), len(self.rating)}
        if len(lengths) != 1:
            raise ValueError(
                "rating_events columns have mismatched lengths: "
                f"userId={len(self.user_id)}, itemId={len(self.item_id)}, "
                f"rating={len(self.rating)}"
            )

    def __len__(self):
        """Return the number of rating events."""
        return len(self.user_id)

    def __getitem__(self, idx):
        """Return the ``(user_id, item_id, rating)`` triple at position ``idx``.

        Negative indices follow normal list semantics.
        """
        return self.user_id[idx], self.item_id[idx], self.rating[idx]
